{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.utils.data as data\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "from image_helper import *\n",
    "from parse_xml_annotations import *\n",
    "from features import *\n",
    "from reinforcement import *\n",
    "from metrics import *\n",
    "import logging\n",
    "import time\n",
    "import os\n",
    "# Expose only the GPU with index 1 to this process through the CUDA_VISIBLE_DEVICES variable.\n",
    "VISIBLE_GPU = \"1\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = VISIBLE_GPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Image and Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Root of the VOC2007 devkit; used for image lists, images and XML annotations.\n",
    "path_voc = \"../datas/VOCdevkit/VOC2007\"\n",
    "\n",
    "# get models\n",
    "print(\"load models\")\n",
    "\n",
    "# Feature extractor and Q-network, both moved to the GPU.\n",
    "model_vgg = getVGG_16bn(\"../models\")\n",
    "model_vgg = model_vgg.cuda()\n",
    "model = get_q_network()\n",
    "model = model.cuda()\n",
    "\n",
    "# Optimizer and loss for the Q-network (only `model` is trained here).\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-6)\n",
    "# Use the fully qualified torch.nn: this notebook never imports `nn` itself, so the\n",
    "# bare name only existed if one of the star imports happened to leak it.\n",
    "criterion = torch.nn.MSELoss().cuda()\n",
    "\n",
    "# get image datas\n",
    "print(\"load images\")\n",
    "\n",
    "image_names = np.array(load_images_names_in_data_set('aeroplane_trainval', path_voc))\n",
    "labels = load_images_labels_in_data_set('aeroplane_trainval', path_voc)\n",
    "# Keep only the images whose label is the string '1'\n",
    "# (presumably the VOC 'object present' flag - TODO confirm against the loader).\n",
    "image_names_aero = [name for name, label in zip(image_names, labels) if label == '1']\n",
    "image_names = image_names_aero\n",
    "images = get_all_images(image_names, path_voc)\n",
    "\n",
    "print(\"aeroplane_trainval image:%d\" % len(image_names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Replay memory (TODO: move this code into `replay.py`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from collections import namedtuple\n",
    "\n",
    "# One stored step of experience: the state, the action taken in it, the state that\n",
    "# followed, and the reward received.\n",
    "Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])\n",
    "\n",
    "# Pick the CUDA tensor constructors when a GPU is available, the CPU ones otherwise.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "if use_cuda:\n",
    "    FloatTensor = torch.cuda.FloatTensor\n",
    "    LongTensor = torch.cuda.LongTensor\n",
    "    ByteTensor = torch.cuda.ByteTensor\n",
    "else:\n",
    "    FloatTensor = torch.FloatTensor\n",
    "    LongTensor = torch.LongTensor\n",
    "    ByteTensor = torch.ByteTensor\n",
    "# Default floating-point tensor type used by the optimisation code.\n",
    "Tensor = FloatTensor\n",
    "class ReplayMemory(object):\n",
    "    \"\"\"Fixed-capacity ring buffer of Transition records.\n",
    "\n",
    "    The buffer grows until it holds `capacity` records; after that each push\n",
    "    overwrites the slot at `position`, which advances cyclically.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, capacity):\n",
    "        \"\"\"Create an empty buffer.\n",
    "\n",
    "        Args:\n",
    "            capacity: maximum number of transitions kept; must be positive.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `capacity` is not positive.\n",
    "        \"\"\"\n",
    "        if capacity <= 0:\n",
    "            # push() indexes into the list and takes a modulo by capacity, so a\n",
    "            # non-positive value could only fail later with an obscure error.\n",
    "            raise ValueError('capacity must be a positive integer, got %r' % (capacity,))\n",
    "        self.capacity = capacity\n",
    "        self.memory = []\n",
    "        self.position = 0\n",
    "\n",
    "    def push(self, *args):\n",
    "        \"\"\"Saves a transition.\n",
    "\n",
    "        Args:\n",
    "            *args: the fields of a Transition (state, action, next_state, reward).\n",
    "        \"\"\"\n",
    "        # Build the record before touching the buffer: if the arguments are wrong,\n",
    "        # no None placeholder is left behind in memory.\n",
    "        transition = Transition(*args)\n",
    "        if len(self.memory) < self.capacity:\n",
    "            self.memory.append(None)\n",
    "        self.memory[self.position] = transition\n",
    "        self.position = (self.position + 1) % self.capacity\n",
    "\n",
    "    def sample(self, batch_size):\n",
    "        \"\"\"Return `batch_size` distinct stored transitions chosen at random.\n",
    "\n",
    "        random.sample raises ValueError when batch_size exceeds the number stored.\n",
    "        \"\"\"\n",
    "        return random.sample(self.memory, batch_size)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of transitions currently stored.\"\"\"\n",
    "        return len(self.memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training schedule\n",
    "epochs = 25        # passes over the image list\n",
    "steps = 10         # maximum number of actions per image\n",
    "\n",
    "# Q-learning settings\n",
    "BATCH_SIZE = 100   # transitions sampled per optimisation step\n",
    "GAMMA = 0.9        # discount applied to the next-state value\n",
    "epsilon = 0.9      # initial probability of a random action; lowered by the training loop\n",
    "CLASS_OBJECT = 1   # class id passed to find_max_bounding_box\n",
    "\n",
    "# Replay buffer shared by the training loop and optimizer_model()\n",
    "memory = ReplayMemory(capacity=1000)\n",
    "\n",
    "def select_action(state):\n",
    "    \"\"\"Choose an action id with an epsilon-greedy policy.\n",
    "\n",
    "    With probability `epsilon` (module-level) a uniformly random action is drawn;\n",
    "    otherwise the action with the largest Q-value predicted by `model` is used.\n",
    "\n",
    "    Args:\n",
    "        state: tensor accepted by `model` once wrapped in a Variable.\n",
    "\n",
    "    Returns:\n",
    "        1-based action id in 1..6; the training loop treats 6 as the terminal\n",
    "        action and subtracts 1 before storing the action in replay memory.\n",
    "    \"\"\"\n",
    "    if random.random() < epsilon:\n",
    "        # randint's upper bound is exclusive, so this draws from 1..6.\n",
    "        return np.random.randint(1, 7)\n",
    "    qval = model(Variable(state))\n",
    "    _, predicted = torch.max(qval.data, 1)\n",
    "    # Cast to a plain Python int so both branches return the same kind of value:\n",
    "    # on newer PyTorch, indexing yields a tensor element, which would then leak into\n",
    "    # the `action == 6` test and into the actions stored in replay memory.\n",
    "    return int(predicted[0]) + 1\n",
    "\n",
    "def optimizer_model():\n",
    "    \"\"\"Run one optimisation step of the Q-network on a replay-memory minibatch.\n",
    "\n",
    "    Returns immediately while `memory` holds fewer than BATCH_SIZE transitions.\n",
    "    Otherwise samples BATCH_SIZE transitions and minimises `criterion` between\n",
    "    Q(s, a) and reward + GAMMA * max_a' Q(s', a'), where the second term is zero\n",
    "    for terminal transitions (those stored with next_state=None).\n",
    "\n",
    "    NOTE(review): written against the pre-0.4 PyTorch API (Variable / volatile);\n",
    "    revisit the no-gradient handling of the target values when upgrading.\n",
    "    \"\"\"\n",
    "    if len(memory) < BATCH_SIZE:\n",
    "        return\n",
    "    transitions = memory.sample(BATCH_SIZE)\n",
    "    # Turn a list of Transitions into one Transition whose fields are tuples.\n",
    "    batch = Transition(*zip(*transitions))\n",
    "\n",
    "    # Mask of the transitions that have a successor state (i.e. are not terminal).\n",
    "    non_final_mask = ByteTensor(tuple(s is not None for s in batch.next_state))\n",
    "    next_states = [s for s in batch.next_state if s is not None]\n",
    "\n",
    "    state_batch = Variable(torch.cat(batch.state)).type(Tensor)\n",
    "    action_batch = Variable(torch.LongTensor(batch.action).view(-1, 1)).type(LongTensor)\n",
    "    reward_batch = Variable(torch.FloatTensor(batch.reward).view(-1, 1)).type(Tensor)\n",
    "\n",
    "    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the\n",
    "    # columns of actions taken\n",
    "    state_action_values = model(state_batch).gather(1, action_batch)\n",
    "\n",
    "    # Compute V(s_{t+1}) for all next states; terminal ones keep the value 0.\n",
    "    next_state_values = Variable(torch.zeros(BATCH_SIZE, 1).type(Tensor))\n",
    "    if next_states:\n",
    "        # torch.cat raises on an empty list, which happens when every sampled\n",
    "        # transition is terminal; in that case all bootstrap values stay zero.\n",
    "        non_final_next_states = Variable(torch.cat(next_states),\n",
    "                                         volatile=True).type(Tensor)\n",
    "        next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]\n",
    "\n",
    "    # Now, we don't want to mess up the loss with a volatile flag, so let's\n",
    "    # clear it. After this, we'll just end up with a Variable that has\n",
    "    # requires_grad=False\n",
    "    next_state_values.volatile = False\n",
    "\n",
    "    # Compute the expected Q values\n",
    "    expected_state_action_values = (next_state_values * GAMMA) + reward_batch\n",
    "\n",
    "    # Compute loss\n",
    "    loss = criterion(state_action_values, expected_state_action_values)\n",
    "\n",
    "    # Optimize the model\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# train procedure\n",
    "print('train the Q-network')\n",
    "for epoch in range(epochs):\n",
    "    print('epoch: %d' %epoch)\n",
    "    now = time.time()\n",
    "    for i, image_name in enumerate(image_names):\n",
    "        # the image part: pixels plus the ground-truth annotation of this image\n",
    "        image = images[i]\n",
    "        annotation = get_bb_of_gt_from_pascal_xml_annotation(image_name, path_voc)\n",
    "        classes_gt_objects = get_ids_objects_from_annotation(annotation)\n",
    "        gt_masks = generate_bounding_box_from_annotation(annotation, image.shape)\n",
    "\n",
    "        # the iou part: the initial region covers the whole image\n",
    "        original_shape = (image.shape[0], image.shape[1])\n",
    "        region_mask = np.ones(original_shape)\n",
    "        # choose the max bounding box\n",
    "        iou = find_max_bounding_box(gt_masks, region_mask, classes_gt_objects, CLASS_OBJECT)\n",
    "\n",
    "        # the initial part\n",
    "        region_image = image\n",
    "        size_mask = original_shape\n",
    "        offset = (0, 0)\n",
    "        history_vector = torch.zeros((4,6))\n",
    "        state = get_state(region_image, history_vector, model_vgg)\n",
    "        done = False\n",
    "\n",
    "        for step in range(steps):\n",
    "\n",
    "            # Select action; the terminal action is forced when the current IoU is above 0.5\n",
    "            if iou > 0.5:\n",
    "                action = 6\n",
    "            else:\n",
    "                action = select_action(state)\n",
    "\n",
    "            # Perform the action and observe new state\n",
    "            if action == 6:\n",
    "                # Terminal action: no successor state, the episode for this image ends.\n",
    "                next_state = None\n",
    "                reward = get_reward_trigger(iou)\n",
    "                done = True\n",
    "            else:\n",
    "                offset, region_image, size_mask, region_mask = get_crop_image_and_mask(original_shape, offset,\n",
    "                                                                   region_image, size_mask, action)\n",
    "                # update history vector and get next state\n",
    "                history_vector = update_history_vector(history_vector, action)\n",
    "                next_state = get_state(region_image, history_vector, model_vgg)\n",
    "\n",
    "                # find the max bounding box in the region image\n",
    "                new_iou = find_max_bounding_box(gt_masks, region_mask, classes_gt_objects, CLASS_OBJECT)\n",
    "                reward = get_reward_movement(iou, new_iou)\n",
    "                iou = new_iou\n",
    "            print('epoch: %d, image: %d, step: %d, reward: %d' %(epoch ,i, step, reward))\n",
    "            # Store the transition in memory (actions are stored 0-based)\n",
    "            memory.push(state, action-1, next_state, reward)\n",
    "\n",
    "            # Move to the next state\n",
    "            state = next_state\n",
    "\n",
    "            # Perform one step of the optimization (on the target network)\n",
    "            optimizer_model()\n",
    "            if done:\n",
    "                break\n",
    "    # Lower epsilon by 0.1 per epoch with a floor of 0.1. Repeated float subtraction\n",
    "    # drifts slightly above 0.1, so the previous `if epsilon > 0.1` test allowed one\n",
    "    # extra decrement that left epsilon near zero; clamping keeps the floor.\n",
    "    epsilon = max(0.1, epsilon - 0.1)\n",
    "    time_cost = time.time() - now\n",
    "    print('epoch = %d, time_cost = %.4f' %(epoch, time_cost))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save the Q-Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialise the entire Q-network object (not just its weights) to disk.\n",
    "MODEL_DIR = '../models/'\n",
    "MODEL_NAME = 'one_object_model_2'\n",
    "Q_NETWORK_PATH = MODEL_DIR + MODEL_NAME\n",
    "torch.save(model, Q_NETWORK_PATH)\n",
    "print('Complete')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
